package com.rising.online.service.impl;

import com.rising.online.bean.CompilerBean;
import com.rising.online.bean.DynamicCompileJavaFileManager;
import com.rising.online.bean.DynamicCompileJavaFileObject;
import com.rising.online.service.CompilerService;
import org.springframework.stereotype.Service;

import javax.tools.*;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Map;

@Service
public class CompilerServiceImpl implements CompilerService {

    /** Charset used both to encode captured console output and to decode it into the run result. */
    private static final String OUTPUT_CHARSET = "utf-8";

    private static final long NANOS_PER_MILLI = 1000000L;

    /**
     * Compiles the source code held by {@code compilerBean}.
     *
     * <p>The source is wrapped in a {@link DynamicCompileJavaFileObject}. Compilation output is routed
     * through a {@link DynamicCompileJavaFileManager} built from {@code compilerBean}. Presumably that
     * manager stores the class bytes that {@code compilerBean.getClassBytes()} later returns; verify
     * against that class. Diagnostics go to {@code compilerBean.getDiagnosticCollector()}.
     *
     * <p>The elapsed time in milliseconds is stored via {@code setCompilerTime}, even if compilation throws.
     *
     * @param compilerBean supplies the compiler, the diagnostic collector, the class name and the source
     * @return {@code true} if the compilation task reported success
     */
    @Override
    public boolean compile(CompilerBean compilerBean) {
        long startNanos = System.nanoTime();
        JavaCompiler compiler = compilerBean.getCompiler();
        DiagnosticCollector<JavaFileObject> diagnosticCollector = compilerBean.getDiagnosticCollector();
        // Standard file manager; it may open JDK/classpath archives, so it must be closed afterwards.
        StandardJavaFileManager standardJavaFileManager = compiler.getStandardFileManager(diagnosticCollector, null, null);
        try {
            JavaFileManager javaFileManager = new DynamicCompileJavaFileManager(standardJavaFileManager, compilerBean);
            // Source object built from the in-memory source code.
            JavaFileObject javaFileObject = new DynamicCompileJavaFileObject(compilerBean.getFullClassName(), compilerBean.getSourceCode());
            JavaCompiler.CompilationTask task = compiler.getTask(null, javaFileManager, diagnosticCollector, null, null, Arrays.asList(javaFileObject));
            return task.call();
        } finally {
            compilerBean.setCompilerTime((System.nanoTime() - startNanos) / NANOS_PER_MILLI);
            try {
                standardJavaFileManager.close();
            } catch (IOException ignored) {
                // Failing to release file handles must not change the compilation outcome.
            }
        }
    }

    /**
     * Loads the compiled class named by {@code compilerBean.getFullClassName()} and invokes its
     * {@code public static void main(String[])} with an empty argument array.
     *
     * <p>While {@code main} runs, {@code System.out} and {@code System.err} are both redirected into a
     * single buffer, so they interleave as on a console. The buffer's text is stored via
     * {@code setRunResult}, and the elapsed time in milliseconds via {@code setRunTime}. If anything
     * fails, the run result is instead set to an error message. The original streams are always restored.
     *
     * <p>NOTE(review): {@code System.setOut}/{@code setErr} are process-wide. Concurrent calls on this
     * singleton service can capture each other's output, or leave another call's buffer installed as
     * {@code System.out}.
     *
     * @param compilerBean supplies the class name and compiled class bytes, and receives the results
     */
    @Override
    public void runMain(CompilerBean compilerBean) {
        PrintStream out = System.out;
        PrintStream err = System.err;
        try {
            long startNanos = System.nanoTime();
            ByteArrayOutputStream captured = new ByteArrayOutputStream();
            // Encode with the same charset used to decode below, independent of the platform default.
            PrintStream capture = new PrintStream(captured, true, OUTPUT_CHARSET);
            System.setOut(capture);
            System.setErr(capture);

            DynamicClassLoader classLoader = new DynamicClassLoader(compilerBean);
            Class<?> mainClass = classLoader.findClass(compilerBean.getFullClassName());
            Method main = mainClass.getMethod("main", String[].class);
            // The cast makes the empty array the single String[] argument instead of the varargs array.
            main.invoke(null, (Object) new String[0]);
            compilerBean.setRunTime((System.nanoTime() - startNanos) / NANOS_PER_MILLI);

            capture.flush();
            compilerBean.setRunResult(captured.toString(OUTPUT_CHARSET));
        } catch (InvocationTargetException e) {
            // Report what the user's main() threw; the reflective wrapper itself has a null message.
            Throwable cause = e.getCause() != null ? e.getCause() : e;
            compilerBean.setRunResult("运行主方法异常：" + cause);
        } catch (Exception e) {
            compilerBean.setRunResult("运行主方法异常：" + e.getMessage());
        } finally {
            // Restore the original console streams.
            System.setOut(out);
            System.setErr(err);
        }
    }

    /**
     * Class loader that defines classes from the byte map returned by {@code compilerBean.getClassBytes()}.
     * Static so that it does not retain a reference to the enclosing service instance.
     */
    private static class DynamicClassLoader extends ClassLoader {
        private final CompilerBean compilerBean;

        public DynamicClassLoader(CompilerBean compilerBean) {
            this.compilerBean = compilerBean;
        }

        /**
         * Defines {@code name} from the compiled bytes when present. Otherwise tries the system class loader.
         *
         * @throws ClassNotFoundException if neither source can provide the class
         */
        @Override
        protected Class<?> findClass(String name) throws ClassNotFoundException {
            Map<String, byte[]> classBytes = compilerBean.getClassBytes();
            byte[] bytes = classBytes == null ? null : classBytes.get(name);
            if (bytes != null) {
                return defineClass(name, bytes, 0, bytes.length);
            }
            try {
                return ClassLoader.getSystemClassLoader().loadClass(name);
            } catch (Exception e) {
                System.out.println("类加载失败：" + e.getMessage());
                return super.findClass(name);
            }
        }
    }
}
